昨天訓練資料時遇到 RuntimeError: output with shape [1, 60, 60] doesn't match the broadcast shape [3, 60, 60]
的錯誤,因此我猜測是圖片有問題,今天寫了一個程式來看看資料集哪裡出錯了,首先我們引用需要用到的模組:
import pandas as pd
import numpy as np
import os
from PIL import Image
這邊與之前相同,把所有圖片的路徑讀取進來:
fasion_df = pd.read_csv("./fashion_product_images_small/myntradataset/styles.csv", on_bad_lines='skip')
fasion_df['image'] = fasion_df.apply(lambda row: str(row['id']) + ".jpg", axis=1)
fasion_df = fasion_df.reset_index(drop=True)
image_name = fasion_df['image'].to_numpy()
print("圖片總數: ", len(image_name))
image_path = [os.path.join("./fashion_product_images_small/myntradataset/images",i) for i in image_name ]
此處我們用到 PIL.Image
中的 getbands()
方法把圖片的 Channels 顯示出來:
img_pil = Image.open(image_path[0])
print(img_pil.getbands())
print(len(img_pil.getbands()))
輸出結果:
('R', 'G', 'B')
3
接著我寫了一個迴圈讀取所有資料集的圖片,第一個迴圈if os.path.isfile(img_path):
是判斷圖片的路徑有沒有存在,第二個是判斷 channels 數是不是 RGB 長度為3 if len(img_pil.getbands()) != 3:
for img_path in image_path:
if os.path.isfile(img_path):
img_pil = Image.open(img_path)
if len(img_pil.getbands()) != 3:
print("image channels not RGB: ", img_path)
else:
print("image path not exist: ", img_path)
輸出結果可以看到有很多 400 多筆圖片不是RGB,且有 5 筆圖片不存在,以下為截圖片段:
明天我們要針對這些例外來修正 DataLoader ,修正好之後再開始訓練模型。